import torch
# Hyperparameters
batch_size = 1  # number of sequences processed in parallel
seq_len = 3  # number of time steps per sequence
input_size = 4  # feature dimension of each per-step input x_t
hidden_size = 2  # dimension of the hidden state h_t (also the cell's output)

# Shapes used below:
#   per-step input x_t: (batch_size, input_size)
#   hidden state h_t:   (batch_size, hidden_size)
#   datasets:           (seq_len, batch_size, input_size)

cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

datasets = torch.randn(seq_len, batch_size, input_size)
# Initial hidden state h_0 is all zeros. The original used torch.randn here,
# which contradicted its own "initialize zero vector" comment.
hidden = torch.zeros(batch_size, hidden_size)

# Unroll the cell manually along dim 0 (time). Each step feeds x_t and the
# previous hidden state to the cell and keeps the returned state. The loop
# variable is named x_t rather than `input` so the builtin is not shadowed.
for idx, x_t in enumerate(datasets):
    print("=" * 20, idx, "=" * 20)
    print("Input size:", x_t.shape)

    hidden = cell(x_t, hidden)
    print("Hidden size:", hidden.shape)
    print(hidden, "\n")


